Global prototype distillation for heterogeneous federated learning

Federated learning is a distributed machine learning paradigm whose goal is to collaboratively train a high-quality global model while the private training data remain local to distributed clients. However, heterogeneous data distributions across clients pose a severe challenge for federated learning systems and seriously degrade model quality. To address this challenge, we propose global prototype distillation (FedGPD) for heterogeneous federated learning to improve the performance of the global model. The intuition is to use global class prototypes as knowledge to guide local training on the client side. As a result, the local objectives become consistent with the global optimum, so that FedGPD learns an improved global model. Experiments show that FedGPD outperforms previous state-of-the-art methods by 0.22% to 1.28% in average accuracy on representative benchmark datasets.

www.nature.com/scientificreports/

The motivation for the algorithm comes from a common observation in deep learning: a model trained on the overall dataset extracts better feature representations than a model trained on a biased subset. For example, a model trained on both cat and dog images can extract far more cat features than a model trained only on dog images. Likewise, in the FL setting, a model trained on a skewed local dataset learns poor feature representations. As a consequence, inconsistent local training across clients degrades the performance of the whole system. The global model, however, can be regarded as a more powerful model trained on all of the distributed data, so it should extract more effective representations than the local models. Based on these considerations, using knowledge from the global model to correct each client's local training is the key to handling heterogeneous data distributions. The intuition is that, with the help of global knowledge, the clients' local drift can be pulled towards the corresponding global model, making local training consistent with the global objective and thereby tackling the non-IID problem.
Prototype learning 19 is introduced into FL to learn an effective global representation by making full use of global model information. A prototype is computed as the mean of the feature vectors within a class, which makes it a good representation of the input data of that class. Therefore, every client can compute class prototypes on its own dataset to represent its local information. During communication, clients send not only model parameters but also local prototypes to the server for aggregation. After gathering the class prototypes from each client, the server computes the global prototypes, which can be regarded as global knowledge and are delivered back to the clients. After receiving the global prototypes, clients use the proposed local objective to regularize local training. With the assistance of knowledge distillation 20, each client uses the global prototypes of its own classes as transferred knowledge to rectify the local loss function, so that local training becomes consistent with the global optimum. Performing knowledge distillation on the client side avoids the need for a public proxy dataset.
FedGPD encourages each instance of the local dataset to approach the global prototype of its class, which improves the performance of the local model. It tackles the heterogeneity problem through a new combination of prototype learning and knowledge distillation. The overall architecture is shown in Fig. 1. The specific steps are as follows. First, the server creates a global model and sends its parameters to every client. Second, each client trains its local model on its local data to obtain the local model parameters. Third, each client transmits its local model parameters and local prototypes $C_{i,k}$ to the server. Finally, the server aggregates the received model parameters and prototypes into the global model parameters $w^{t+1}$ and the global prototypes $C_k$, respectively, and sends the updated global model $w^{t+1}$ to each client again.

Contributions: (1) We propose a novel federated learning algorithm called FedGPD. To address the data heterogeneity problem, the global prototype is introduced as distilled knowledge to regularize local training and to avoid the public proxy dataset problem.

Suppose there are $N$ clients, and client $i$ holds a local dataset $D_i$. The general goal of FL is to learn a global model weight $w$ over the dataset $D = \{D_i\}_{i=1}^{N}$, while the clients' raw data are never shared with others. The objective of FL is to optimize

$$\min_{w} \mathcal{L}(w) = \sum_{i=1}^{N} \frac{|D_i|}{|D|} L_i(w),$$

where $L_i(w) = \mathbb{E}_{(x,y)\sim D_i}\left[\ell_i\left(F_i(w; x), y\right)\right]$ is the local objective function of client $i$, $F_i(w; x)$ represents the local model and $\ell_i(\cdot, \cdot)$ is the loss function.
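The server-side aggregation of model parameters can be illustrated with a minimal sketch. It assumes FedAvg-style averaging in which client $i$ is weighted by $|D_i|/|D|$ (the text does not spell out the parameter-aggregation rule) and represents each model as a flat list of floats; the function name `aggregate_parameters` is hypothetical.

```python
from typing import List, Sequence


def aggregate_parameters(client_params: Sequence[List[float]],
                         client_sizes: Sequence[int]) -> List[float]:
    """Weighted mean of client parameter vectors with weights |D_i| / |D|."""
    if not client_params or len(client_params) != len(client_sizes):
        raise ValueError("need at least one client and one size per client")
    total = sum(client_sizes)
    global_params = [0.0] * len(client_params[0])
    for params, size in zip(client_params, client_sizes):
        weight = size / total  # |D_i| / |D|
        for j, value in enumerate(params):
            global_params[j] += weight * value
    return global_params


# Two toy clients holding 100 and 300 samples.
w_next = aggregate_parameters([[1.0, 2.0], [3.0, 6.0]], [100, 300])
print(w_next)  # [2.5, 5.0]
```

In a PyTorch implementation, the same weighted sum would be applied tensor by tensor to the entries of each client model's `state_dict()`.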
In the independent and identically distributed (IID) data setting, FL achieves superior performance. However, real-world user data are often non-IID, which severely damages the aggregated model. Statistical heterogeneity across clients is therefore the most significant challenge for FL. Many studies have tackled the non-IID problem in FL, and they fall mainly into two directions: (1) improving the aggregation algorithm, such as FedNova and FedMA; (2) stabilizing local training, such as MOON and FedProx.
For global model aggregation, FedMA uses Bayesian non-parametric methods to match neurons and averages them in a layer-wise manner. FedNova normalizes the local updates before averaging them. In 21, q-Fair Federated Learning modifies the aggregation algorithm so that all participants can maximize the performance of their local models under a given global model. FedAvgM 22 proposes a mitigation strategy for the Federated Averaging algorithm via server momentum. FedMIX 23 uses Mean Augmented Federated Learning (MAFL), in which clients send and receive averaged local data, to improve performance. FedBE 24 takes a Bayesian inference perspective and ensembles global models, leading to more robust aggregation. MHAT 25

Knowledge distillation
Knowledge Distillation (KD) is a technique for transferring knowledge from a complex but effective teacher model $w_T$ to a lightweight but less effective student model $w_S$ 27-29. It keeps the student model as light as possible without losing performance. The intuition of KD is to minimize the divergence between the logit outputs of the teacher model and the student model, so that the student learns from the teacher and achieves better performance by approaching the teacher's outputs 30,31. The knowledge distillation loss can be computed by the Kullback-Leibler divergence:

$$\ell_{KD} = \mathrm{KL}\left(p_T \,\|\, p_S\right) = \sum_{k=1}^{K} p_{T,k} \log \frac{p_{T,k}}{p_{S,k}},$$

where $p_T$ and $p_S$ are the softmax outputs of the teacher and student logits. Based on the type of transferred knowledge, knowledge distillation can be divided into three categories: (1) output transfer; (2) feature transfer; (3) relation transfer. FedGPD belongs to the second category: the global prototype is taken as knowledge to correct local training.
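The KL-based distillation loss can be sketched in plain Python as below. This is an illustrative sketch, not the paper's code: it assumes temperature-softened softmax outputs on both the teacher and the student side, and the names `softmax_with_temperature` and `kd_loss` are hypothetical.

```python
import math
from typing import List, Sequence


def softmax_with_temperature(logits: Sequence[float], T: float = 1.0) -> List[float]:
    """Softened probabilities exp(z_k / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # shift by the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(teacher_logits: Sequence[float], student_logits: Sequence[float],
            T: float = 1.0) -> float:
    """KL(p_T || p_S) between the temperature-softened teacher and student outputs."""
    p = softmax_with_temperature(teacher_logits, T)
    q = softmax_with_temperature(student_logits, T)
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)


teacher = [2.0, 1.0, 0.1]
student = [0.1, 1.0, 2.0]
print(kd_loss(teacher, teacher, T=2.0))  # 0.0 for identical outputs
print(kd_loss(teacher, student, T=2.0))  # positive when the outputs differ
# A higher temperature gives a flatter (smoother) soft-target distribution.
print(max(softmax_with_temperature(teacher, 1.0)) > max(softmax_with_temperature(teacher, 5.0)))  # True
```

Hinton-style distillation often multiplies this term by $T^2$ so that gradient magnitudes stay comparable when $T$ changes; the sketch leaves that factor out.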
Moreover, the concept of KD has been brought into FL to deal with the heterogeneity problem. FD 32 synchronizes per-label logits accumulated during local training; the logits averaged per label (over local steps and clients) are then used as a distillation regularizer in the next round's local training. FedDF 33 treats local models as teacher models and evaluates each of them on unlabeled public data stored on the server side; the averaged logit outputs of the teacher models are then used as knowledge to train the global student model. FedBE introduces Bayesian model ensembles into FL, so that the ensemble predictions, used as pseudo-labels for public unlabeled data, act as the teacher for training a global student model. These methods usually treat local models as teachers whose knowledge is transferred to a student global model. FedAUX 34 improves performance by deriving maximum utility from unlabeled auxiliary data, weighting the ensemble predictions on the auxiliary data according to the certainty of each client model.
A major drawback of these methods is that they all need a public dataset or a proxy unlabeled dataset to perform knowledge distillation. To address this problem, FedGPD uses the global prototype as transferred knowledge, and knowledge distillation takes place on the client side. Thus, neither a proxy dataset nor additional training on the server side is required.

Proposed global prototype distillation federated learning
This section summarizes the proposed algorithm, which is given in full in Algorithm 1; an overview of its learning procedure is illustrated in Fig. 2. The general flow is as follows. First, the local data $X$ are fed into the client's neural network, and the representation is projected by the MLP into $z$. Second, the client network maps $z$ to logits, which are passed through a softmax to compute classification probabilities. Finally, the local prototypes obtained on the client are uploaded to the server.

Problem formulation
The heterogeneous FL problem considered here is formulated as follows. Suppose there are $N$ clients in the system. Each client $i$ possesses a labeled image dataset $D_i = \{(x_m^i, y_m^i)\}_{m=1}^{|D_i|}$, where $x_m^i$ is the $m$-th data sample of the $i$-th client, $y_m^i \in \{1, 2, \dots, K\}$ is the corresponding label among $K$ classes, and $|D_i|$ denotes the number of training samples owned by the $i$-th client. The datasets of different clients may be drawn from different distributions $p_i$, $i = 1, \dots, N$. A Dirichlet distribution $\mathrm{Dir}_N(\beta)$ is used to simulate label distribution skew among clients, which is the most common federated setting. The goal is to train models for the image classification task among the clients. In other words, in the non-IID setting where the clients' data distributions $p_i$ differ considerably from each other, the goal is to optimize

$$\min_{w} \mathcal{L}(w) = \sum_{i=1}^{N} \frac{|D_i|}{|D|} L_i(w),$$

where $L_i(w)$ is the local objective of client $i$.

Architecture of FedGPD
• Client local network design. A typical neural network generally consists of three components: the base encoder extracts representations from the inputs; the projection head maps the representation into a new feature space; and the output layer produces the classification decision for each class. The base encoder and the projection head together form the representation layer, while the output layer is the decision layer 14. The representation layer maps the input data $x$ from the original feature space to the representation space. This extracting and mapping function is denoted by $z = f_r(w_r; x)$, where $z$ is the mapped representation of $x$ and $w_r$ denotes the learnable parameters of the representation layer. The decision layer makes the classification decision for the specific learning task by generating a prediction from the mapped representation $z$; the prediction function is denoted by $s = f_d(w_d; z)$, where $s$ is the output prediction and $w_d$ denotes the learnable parameters of the decision layer. Therefore, the complete network can be denoted as

$$F(w; x) = f_d\left(w_d; f_r(w_r; x)\right), \quad w = \{w_r, w_d\}.$$

A prototype is a proxy for one class in classification tasks and can be measured as the mean of the feature vectors in that class. For the $i$-th client, the local prototype $C_{i,k}$ is the mean representation of the inputs in class $k$:

$$C_{i,k} = \frac{1}{|D_{i,k}|} \sum_{(x, y) \in D_{i,k}} f_r(w_r; x),$$

where $D_{i,k}$ is the subset of the local dataset of client $i$ consisting of the training data belonging to class $k$.
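The local prototype $C_{i,k}$ can be computed as in the following sketch. It assumes the representations $z = f_r(w_r; x)$ have already been produced by the representation layer and are given as plain lists; `local_prototypes` is a hypothetical helper, and a real implementation would operate on PyTorch tensors instead.

```python
from typing import Dict, List, Sequence


def local_prototypes(representations: Sequence[Sequence[float]],
                     labels: Sequence[int]) -> Dict[int, List[float]]:
    """C_{i,k}: mean of the representations z = f_r(w_r; x) over D_{i,k}.

    Only classes present on the client (non-empty D_{i,k}) get a prototype.
    """
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for z, y in zip(representations, labels):
        acc = sums.setdefault(y, [0.0] * len(z))
        for j, value in enumerate(z):
            acc[j] += value
        counts[y] = counts.get(y, 0) + 1
    return {k: [s / counts[k] for s in acc] for k, acc in sums.items()}


# Toy client: 2-dimensional representations of samples from classes 0 and 3.
Z = [[1.0, 0.0], [3.0, 2.0], [0.0, 4.0]]
Y = [0, 0, 3]
print(local_prototypes(Z, Y))  # {0: [2.0, 1.0], 3: [0.0, 4.0]}
```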
Figure 2 shows the local network architecture. The network architecture is similar to that of FedProto 35, but this work uses a different method to solve the non-IID problem. The logit output is used for the conventional cross-entropy loss, while the local prototype from the representation layer serves as an input to the distillation term of the local loss function. The design of this prototype regularization term in the local training loss differs from FedProto.
• Whole algorithm design. The intuition of FedGPD comes from knowledge distillation and prototype learning. In the non-IID setting, a local model tends to overfit its own biased local dataset, which significantly degrades the effectiveness of the global model. The aggregated global model, however, is expected to perform better than each local model, because it has a superior ability to extract information from the decentralized global dataset. Therefore, the global model can be considered a teacher model that instructs each local model, acting as a student, to avoid overfitting its biased local dataset. However, directly using the logits of the global model as knowledge could lose part of the feature information of the inputs. Since a prototype is a representation of the input data and therefore contains more feature information, the global prototype, as the representation of the global dataset, is employed to guide the local model to better fit the overall data distribution. Clients upload their local prototypes together with their model parameters to the server for aggregation.
Global prototype. After gathering the local prototypes from $n$ clients, the server computes the aggregated global prototype by averaging them: $C_k$, the global prototype of class $k$, is the weighted mean of the local prototypes of class $k$. The overall global prototype set is therefore $\{C_k\}_{k=1}^{K}$, where $K$ is the total number of classes. Algorithm 1 describes FedGPD in detail; it consists of four main steps.
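The server-side prototype aggregation can be sketched as follows. The text states only that $C_k$ is a weighted mean of the local prototypes of class $k$; this sketch assumes weights proportional to each client's class-$k$ sample count $|D_{i,k}|$, with only the clients that hold class $k$ contributing. `global_prototypes` is a hypothetical name.

```python
from typing import Dict, List, Sequence


def global_prototypes(client_protos: Sequence[Dict[int, List[float]]],
                      client_counts: Sequence[Dict[int, int]]) -> Dict[int, List[float]]:
    """C_k: weighted mean of the local prototypes C_{i,k} for each class k.

    Assumption: client i's weight for class k is |D_{i,k}| divided by the
    total number of class-k samples over the clients that hold class k.
    """
    totals: Dict[int, int] = {}
    for counts in client_counts:
        for k, n in counts.items():
            totals[k] = totals.get(k, 0) + n
    result: Dict[int, List[float]] = {}
    for protos, counts in zip(client_protos, client_counts):
        for k, proto in protos.items():
            weight = counts[k] / totals[k]
            acc = result.setdefault(k, [0.0] * len(proto))
            for j, value in enumerate(proto):
                acc[j] += weight * value
    return result


# Client A holds classes 0 and 1; client B holds only class 0.
protos = [{0: [1.0, 1.0], 1: [5.0, 0.0]}, {0: [3.0, 3.0]}]
counts = [{0: 10, 1: 4}, {0: 30}]
print(global_prototypes(protos, counts))  # {0: [2.5, 2.5], 1: [5.0, 0.0]}
```

Replacing the count-based weights with uniform weights would give the plain average over the clients that hold each class.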

Global prototype distillation
Local objective. The designed local loss function consists of two parts. The first part is the standard cross-entropy loss term $\ell_{CE}$ widely used in supervised learning. The second part is the proposed global prototype distillation loss term, denoted $\ell_{GPD}$. This term makes the local model learn global information at the representation layer, so that the local prototypes approach the global prototypes. A hyperparameter $\lambda$ controls the weight of $\ell_{GPD}$. The final loss function of the local model is

$$\ell = \ell_{CE} + \lambda\, \ell_{GPD}.$$

Global prototype distillation loss. To mitigate over-biased local models, this work brings knowledge distillation into the local training phase. Using the global prototype as teaching information, FedGPD formulates the distillation loss as

$$\ell_{GPD} = \frac{1}{K_i} \sum_{k=1}^{K_i} \mathrm{KL}\left(\hat{C}_k^{T} \,\|\, \tilde{C}_k^{T}\right),$$

where $\hat{C}_k^{T}$ is the softmax output of the global prototype at temperature $T$ for class $k$, $\tilde{C}_k^{T}$ is the softmax output of the local prototype at temperature $T$ for class $k$, and $K_i$ is the total number of classes on client $i$. The output dimension of the representation layer determines the size of the prototype, which affects the subsequent knowledge distillation; the output dimension is therefore treated as a tunable hyperparameter. By default, it is set to 256, a common choice in other algorithms, for a fair comparison.
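A minimal sketch of the local objective (cross-entropy plus $\lambda$ times the prototype distillation term, with $\lambda$ denoting the weight hyperparameter) for a single sample follows. Assumptions: prototypes are plain lists (3-dimensional here instead of 256 for brevity), $\ell_{GPD}$ averages the KL divergence over the classes present on the client, and $T = 2$, $\lambda = 0.05$ are the defaults stated in the experimental setup; all function names are hypothetical.

```python
import math
from typing import Dict, List, Sequence


def softmax(values: Sequence[float], T: float = 1.0) -> List[float]:
    """Temperature-softened softmax, shifted by the maximum for stability."""
    scaled = [v / T for v in values]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)


def gpd_loss(global_protos: Dict[int, List[float]],
             local_protos: Dict[int, List[float]], T: float = 2.0) -> float:
    """Mean over the client's K_i classes of KL(softmax(C_k/T) || softmax(C_{i,k}/T))."""
    total = sum(kl(softmax(global_protos[k], T), softmax(local_protos[k], T))
                for k in local_protos)
    return total / len(local_protos)


def cross_entropy(logits: Sequence[float], label: int) -> float:
    return -math.log(softmax(logits)[label])


def local_loss(logits: Sequence[float], label: int,
               global_protos: Dict[int, List[float]],
               local_protos: Dict[int, List[float]],
               lam: float = 0.05, T: float = 2.0) -> float:
    """l = l_CE + lambda * l_GPD, with lambda = 0.05 and T = 2 by default."""
    return cross_entropy(logits, label) + lam * gpd_loss(global_protos, local_protos, T)


global_protos = {0: [1.0, 0.0, 0.0], 1: [0.0, 1.0, 0.0]}  # received from the server
local_protos = {0: [0.5, 0.2, 0.1]}                       # this client holds class 0 only
print(local_loss([2.0, 0.5, -1.0], 0, global_protos, local_protos))
```

In actual training, the local prototypes would presumably be computed from the current batch's representations so that gradients of $\ell_{GPD}$ flow back into $w_r$, while the global prototypes act as fixed targets received from the server.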
This work uses PyTorch to implement the algorithm and the baselines. An SGD optimizer with a learning rate of 0.01 is used for all algorithms, with weight decay 0.00001 and momentum 0.9. The training batch size is 64. The number of local training epochs is 10 by default, and the number of communication rounds is 50 for MNIST and 100 for CIFAR-10/CIFAR-100. The distillation temperature $T$ and the global prototype distillation loss weight $\lambda$ are tunable hyperparameters, set by default to $T = 2$ and $\lambda = 0.05$.
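The default settings above can be gathered into a single configuration object. The dataclass and its field names are hypothetical; only the values come from the text (the 256-dimensional prototype size is the default stated for the representation layer).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FedGPDConfig:
    """Default hyperparameters stated in the text (field names are hypothetical)."""
    lr: float = 0.01             # SGD learning rate
    momentum: float = 0.9        # SGD momentum
    weight_decay: float = 1e-5   # SGD weight decay (0.00001)
    batch_size: int = 64
    local_epochs: int = 10       # E
    rounds: int = 100            # 100 for CIFAR-10/CIFAR-100, 50 for MNIST
    temperature: float = 2.0     # distillation temperature T
    gpd_weight: float = 0.05     # lambda, weight of the GPD loss term
    proto_dim: int = 256         # output dimension of the representation layer


cifar_cfg = FedGPDConfig()
mnist_cfg = FedGPDConfig(rounds=50)
print(mnist_cfg.rounds, cifar_cfg.rounds)  # 50 100
```

With PyTorch, the optimizer fields would map onto `torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)`.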
Following previous works, a Dirichlet distribution is used to generate the non-IID data distribution among clients. Concretely, we sample $p_k \sim \mathrm{Dir}_N(\beta)$ and allocate a $p_{k,i}$ proportion of the instances of class $k$ to client $i$, where $\beta$ is a concentration parameter controlling the degree of non-IIDness. The default configuration is listed in Table 1.
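The Dirichlet label-skew partition described above can be sketched as follows. It samples $\mathrm{Dir}_N(\beta)$ by normalizing independent $\mathrm{Gamma}(\beta, 1)$ draws, a standard construction, and splits each class by cumulative proportions; the function name `dirichlet_partition` and the fixed seed are assumptions for illustration.

```python
import random
from typing import Dict, List, Sequence


def dirichlet_partition(labels: Sequence[int], num_clients: int, beta: float,
                        seed: int = 0) -> Dict[int, List[int]]:
    """Split sample indices among clients with label skew drawn from Dir_N(beta).

    For each class k: sample p_k ~ Dir_N(beta) and give client i a p_{k,i}
    share of that class's samples. Smaller beta means more skewed clients.
    """
    rng = random.Random(seed)
    clients: Dict[int, List[int]] = {i: [] for i in range(num_clients)}
    by_class: Dict[int, List[int]] = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    for idxs in by_class.values():
        rng.shuffle(idxs)
        # Dirichlet sample via normalized Gamma(beta, 1) draws.
        draws = [rng.gammavariate(beta, 1.0) for _ in range(num_clients)]
        total = sum(draws)
        p_k = [d / total for d in draws] if total > 0 else [1.0 / num_clients] * num_clients
        # Cumulative cut points turn the proportions into contiguous slices.
        cuts, acc = [], 0.0
        for share in p_k[:-1]:
            acc += share
            cuts.append(int(acc * len(idxs)))
        bounds = [0] + cuts + [len(idxs)]
        for i in range(num_clients):
            clients[i].extend(idxs[bounds[i]:bounds[i + 1]])
    return clients


labels = [k for k in range(10) for _ in range(100)]  # 10 classes x 100 samples
parts = dirichlet_partition(labels, num_clients=10, beta=0.5)
print(sum(len(v) for v in parts.values()))  # 1000
```

Rerunning with `beta=0.1` instead of 0.5 typically leaves several clients with few or no samples of most classes, which is the severe-skew regime discussed in the experiments.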

Accuracy results
Table 2 lists the test accuracy of FedGPD and the other baselines. In the heterogeneous setting, SOLO gives the worst results of all methods, demonstrating the benefit of federated learning. FedAvg, the basic FL algorithm, uses cross-entropy to train the local models and a weighted average to aggregate the model parameters, but it achieves relatively low accuracy in the non-IID setting. FedProx makes a small modification to FedAvg by adding a proximal term to the loss function; as a result, its accuracy is very close to that of FedAvg, especially when the parameter $\mu$ is small. MOON uses model-level contrastive federated learning with a model-contrastive loss to deal with the heterogeneity problem. MOON performs well in mildly heterogeneous settings and even outperforms FedGPD when $\beta = 1$. However, when the concentration parameter $\beta$ is 0.1, meaning that the non-IID problem is severe and closer to practical situations, FedGPD provides better results. According to Figs. 3 and 4, FedGPD outperforms FedProx on the MNIST dataset and achieves higher test accuracy than MOON on the CIFAR10 dataset. Overall, FedGPD achieves competitive or better performance than the other methods on the different datasets, which demonstrates that FedGPD can effectively rectify local training through global prototype distillation.

Comparison with logits distillation method
The original KD transfers knowledge by minimizing the KL divergence between the prediction logits of the teacher and the student, whereas this work chooses feature-based rather than logit-based knowledge distillation. To demonstrate the advantage of global prototype distillation, two additional distillation methods are designed for comparison. The first method, denoted method 1, directly uses the logit output of the previous round's global model as distilled knowledge. Specifically, the distillation loss in method 1 is computed as $\ell_{method1} = \mathrm{KL}(p \,\|\, q)$, where $p = [p_1, p_2, \dots, p_K]$ with $p_k = \exp(z_k / T) / \sum_{j=1}^{K} \exp(z_j / T)$, $z_k$ is the logit of the $k$-th class from the previous global model, and $q$ is the logit output of the local model.
The second method, denoted method 2, aggregates the logits of the local models into global logits that serve as distilled knowledge. Specifically, the distillation loss in method 2 is computed as $\ell_{method2} = \mathrm{KL}(p \,\|\, q)$, where $p = [p_1, p_2, \dots, p_K]$, $p_k$ is obtained by averaging the logits of all local models for the $k$-th class, and $q$ is the logit output of the local model. Table 3 lists the test accuracy of FedGPD and the two comparison methods. The accuracy of FedGPD is 0.27% to 3.08% higher than that of the other two schemes. FedGPD performs better than the other two distillation methods because global prototypes carry more useful information than logit outputs. Therefore, the global prototype distillation of FedGPD transfers more effective knowledge to rectify local training.
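The construction of the global logits in method 2 is described only briefly. One plausible reading, in the spirit of FD, is sketched below: each client reports, for every class it holds, the mean logit vector of its samples of that class, and the server averages these vectors across clients. The function name `global_logits_per_class` and the unweighted averaging are assumptions.

```python
from typing import Dict, List, Sequence


def global_logits_per_class(client_logits: Sequence[Dict[int, List[float]]]
                            ) -> Dict[int, List[float]]:
    """Average, over the clients that hold class k, each client's mean logit
    vector for its samples of class k (an FD-style reading of method 2)."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for per_class in client_logits:
        for k, logits in per_class.items():
            acc = sums.setdefault(k, [0.0] * len(logits))
            for j, v in enumerate(logits):
                acc[j] += v
            counts[k] = counts.get(k, 0) + 1
    return {k: [v / counts[k] for v in acc] for k, acc in sums.items()}


client_a = {0: [4.0, 0.0], 1: [0.0, 2.0]}
client_b = {0: [2.0, 2.0]}
print(global_logits_per_class([client_a, client_b]))  # {0: [3.0, 1.0], 1: [0.0, 2.0]}
```

For a local sample of class $y$, the softened version of the class-$y$ vector would then act as $p$ in $\mathrm{KL}(p \,\|\, q)$.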

Influence of number of local epochs
Studying the influence of the number of local epochs $E$ reveals a trade-off between local epochs and model performance. When $E$ is set to 1, the local update is very small, so the local network cannot be trained well and the test accuracy is relatively low. However, when $E$ is too large, the local model overfits its skewed local dataset, which degrades test accuracy. As Table 4 shows, setting the number of local epochs to 10 turns out to be a suitable choice.

Influence of data heterogeneity
It is important to investigate the influence of data heterogeneity by varying the concentration parameter $\beta$ of the Dirichlet distribution. A smaller $\beta$ makes the data distribution more skewed. The results are listed in Table 5. Although MOON achieves better test accuracy when $\beta = 1$, FedGPD achieves better performance when $\beta = 0.1$ or 0.5, which demonstrates that global prototype distillation is more effective under highly heterogeneous conditions.

Influence of coefficients in the loss function
The two important coefficients in $\ell_{GPD}$ are the distillation temperature $T$ and the weight hyperparameter $\lambda$. The distillation temperature $T$ controls how soft the targets are: as the temperature increases, the distribution of the soft targets becomes smoother. The weight hyperparameter $\lambda$ controls the proportion of $\ell_{GPD}$ in the whole loss function. This work tries $T \in \{2, 5\}$ and $\lambda \in \{0.05, 0.1, 0.5, 1\}$ on the CIFAR10 dataset, as shown in Table 6. $T = 2$ is found to be a good choice even when $\beta$ is 0.1. For the weight hyperparameter $\lambda$, when $\beta$ is 0.5, the choice of $\lambda$ slightly influences test accuracy, especially for $\lambda = 0.05$ or 0.5. Therefore, according to Fig.

7. FedGPD achieves equivalent or even better performance. Although FedGPD, unlike FedProto, aggregates a global model, its effectiveness is confirmed in this setting. With the global prototype distillation correction term, the local models can effectively handle their respective heterogeneous datasets, and the global model also performs well on the entire test dataset.

Communication cost and limitation
Communication cost is always a challenge in FL. In general, the communication cost of most FL algorithms is determined by the size of the shared model parameters. For instance, in FedAvg, the clients and the server share the same network architecture, so the size of the model parameters determines the communication cost. Taking CIFAR10 as an example, the simple CNN model in this work contains two 5×5 convolution layers followed by 2×2 max pooling and two fully connected layers with ReLU activation, with 85486 parameters in total. In FedGPD, prototypes must be transmitted between the server and the clients in addition to the model parameters. However, the prototypes are small compared to the model parameters: the default prototype size for each class is 256, so the extra communication cost is at most 2560 values (10 classes × 256), which is negligible. Therefore, the proposed global prototype distillation method can outperform most state-of-the-art FL schemes with only a slight communication overhead.
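The overhead estimate above is simple arithmetic, reproduced here as a check. The 85486 parameter count and the 256-dimensional prototypes are taken from the text; transmitting values as 4-byte float32 numbers is an assumption, and the helper name is hypothetical.

```python
def extra_prototype_values(num_classes: int, proto_dim: int) -> int:
    """Upper bound on extra values per transfer: one proto_dim-sized vector per class."""
    return num_classes * proto_dim


model_params = 85486                      # simple CNN on CIFAR10, as reported in the text
extra = extra_prototype_values(10, 256)   # 10 classes x 256 dimensions
print(extra, f"{extra / model_params:.1%}")  # 2560 3.0%

bytes_per_value = 4  # assumption: values sent as float32
print(extra * bytes_per_value, model_params * bytes_per_value)  # 10240 341944
```

The prototypes therefore add roughly 3% to each model transfer in this configuration.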
One possible limitation is that sending local prototypes along with model parameters could expose the local dataset to privacy leakage. However, the focus of this work is the heterogeneous data distribution problem in FL. Since the prototype of each class has only 256 dimensions, which is small compared to the model parameters, local prototypes can be transmitted with lightweight encryption. Therefore, this security issue can be addressed by applying a lightweight encryption algorithm.
introduces a novel model-heterogeneous aggregation training federated learning scheme to extract the update information of the heterogeneous models of all clients. FedGPD is orthogonal to the above methods, and combining them is promising because FedGPD focuses on the local training phase instead of model aggregation. For stabilizing local training, FedProx adds a proximal term to the local objective: the distance between the global model and the local model is computed to keep the local model parameters close to the global optimum instead of overfitting the local dataset. SCAFFOLD 26 introduces control variates to correct the drift in the local training phase. MOON takes advantage of model-level contrastive learning to decrease the distance between the global model representation and the local model representation, while increasing the distance between the local model representation and the previous local model representation.
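FedProx's proximal term, mentioned above, penalizes the squared Euclidean distance between the local weights $w$ and the global weights $w^t$, scaled by $\mu/2$. A plain-Python sketch over flattened weight vectors follows; the helper name `fedprox_proximal_term` is hypothetical.

```python
from typing import Sequence


def fedprox_proximal_term(local_w: Sequence[float], global_w: Sequence[float],
                          mu: float) -> float:
    """(mu / 2) * ||w - w^t||^2, added to the local loss in FedProx."""
    return 0.5 * mu * sum((a - b) ** 2 for a, b in zip(local_w, global_w))


# Local weights that drifted from the global model are penalized.
print(fedprox_proximal_term([1.0, 2.0], [0.0, 0.0], mu=0.1))
```

A larger $\mu$ pulls the local weights more strongly towards $w^t$; with a small $\mu$, FedProx behaves much like FedAvg, consistent with the accuracy comparison reported above.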

Figure 2. Overview of the local network architecture.
(1) The server initializes a random global model and sends its parameters to each client participating in training. (2) Each client updates its local model with the proposed loss function after receiving the global model. (3) Clients upload their local model parameters and local prototypes to the server for aggregation. (4) The server aggregates the local model parameters and local prototypes to produce the global model and the global prototypes, respectively, for the next round. These four steps are repeated until convergence or until the maximum number of communication rounds is reached.

Figure 3. Test accuracy with different number of communication rounds on CIFAR10.

Figure 4. Test accuracy with different number of communication rounds on MNIST.

Experiments mainly compare FedGPD with three state-of-the-art federated learning algorithms, (1) FedAvg, (2) FedProx and (3) MOON, as well as with SOLO, where each client trains its own model on its local dataset without federated learning. The experiments use three common open-source datasets: MNIST, CIFAR-10 and CIFAR-100. For fairness, the same model is used for local training in all algorithms. Two different neural networks are applied: (1) a simple CNN model plus a fully connected layer as the representation layer for MNIST and CIFAR-10; (2) ResNet-18 plus a fully connected layer as the representation layer for CIFAR-100. Specifically, the simple CNN model contains two 5×5 convolution layers followed by 2×2 max pooling and two fully connected layers with ReLU activation.

Table 1. The default configuration.

Table 2. Test accuracy of FedGPD and the other methods on the test dataset.
Comparison with FedProto

FedProto uses the MSE to measure the distance between local and global prototypes, and uses an n-way k-shot split to divide the dataset. Consequently, each client's training and test sets contain the same classes, and there is no global model to aggregate, so the local models are more personalized and the whole setting is entirely different. If FedProto were applied directly to the Dirichlet distribution setting, the local models would not acquire the global model's ability to handle the whole test dataset, because each local model is tested only on its own partial test set. Moreover, experiments with an n-way k-shot split cannot simulate the long-tail problem found in reality. Nevertheless, for a fair comparison with FedProto, we follow its experimental setting (dividing the dataset by 3-way 100-shot) and use global prototype distillation to rectify the local models.
5, T = 2 and λ = 0.05 are set as the default values.

Table 3. Test accuracy comparison with other KD methods.

Table 4. Test accuracy of FedGPD with different numbers of local epochs. Significant values are in bold.

Table 5. Test accuracy of FedGPD and the other methods on the test dataset. Significant values are in bold. Columns: Method, β = 0.1 (%), β = 0.5 (%), β = 1 (%).

Experiments evaluate the performance of the local models instead of the global model. Results are shown in the Table

Table 7. Average test accuracy of local models compared with FedProto.